
package com.jstarcraft.ai.jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.clustering.KClustererBase;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.linear.ConstantVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameter.ParameterHolder;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;

/**
 * Base class for various Kernel K Means implementations. Because the Kernelized
 * version is more computationally expensive, only the clustering methods where
 * the number of clusters is specified a priori are supported. <br>
 * <br>
 * KernelKMeans keeps a reference to the data passed in for clustering so that
 * queries can be conveniently answered, such as getting
 * {@link #findClosestCluster(com.jstarcraft.ai.jsat.linear.Vec) the closest
 * cluster} or finding the {@link #meanToMeanDistance(int, int) distance between
 * means}
 * 
 * @author Edward Raff
 */
public abstract class KernelKMeans extends KClustererBase implements Parameterized {

    private static final long serialVersionUID = -5294680202634779440L;

    /**
     * The kernel trick to use
     */
    @ParameterHolder
    protected KernelTrick kernel;

    /**
     * The list of data points that this was trained on
     */
    protected List<Vec> X;
    /**
     * The weight of each data point
     */
    protected Vec W;
    /**
     * The acceleration cache for the kernel
     */
    protected DoubleList accel;
    /**
     * The value of k(x,x) for every point in {@link #X}
     */
    protected double[] selfK;

    /**
     * The value of the un-normalized squared norm for each mean
     */
    protected double[] meanSqrdNorms;

    /**
     * The normalizing constant for each mean. Generally this would be
     * 1/owned[k]<sup>2</sup>
     */
    protected double[] normConsts;

    /**
     * The weighted number of data points owned by each mean
     */
    protected double[] ownes;

    /**
     * A temporary space for updating ownership designations for each datapoint.
     * When done, this will store the final designations for each point
     */
    protected int[] newDesignations;
    /**
     * The maximum number of iterations of the KMeans algorithm. Defaults to
     * {@link Integer#MAX_VALUE}.
     */
    protected int maximumIterations = Integer.MAX_VALUE;

    /**
     * 
     * @param kernel the kernel to use
     */
    public KernelKMeans(KernelTrick kernel) {
        this.kernel = kernel;
    }

    /**
     * Copy constructor. Performs a deep copy of the kernel, the stored data points,
     * and all cached arrays.
     * 
     * @param toCopy the object to copy
     */
    public KernelKMeans(KernelKMeans toCopy) {
        this.kernel = toCopy.kernel.clone();
        this.maximumIterations = toCopy.maximumIterations;
        if (toCopy.X != null) {
            this.X = new ArrayList<>(toCopy.X.size());
            for (Vec v : toCopy.X)
                this.X.add(v.clone());

        }
        if (toCopy.accel != null)
            this.accel = new DoubleArrayList(toCopy.accel);
        if (toCopy.selfK != null)
            this.selfK = Arrays.copyOf(toCopy.selfK, toCopy.selfK.length);

        if (toCopy.meanSqrdNorms != null)
            this.meanSqrdNorms = Arrays.copyOf(toCopy.meanSqrdNorms, toCopy.meanSqrdNorms.length);

        if (toCopy.normConsts != null)
            this.normConsts = Arrays.copyOf(toCopy.normConsts, toCopy.normConsts.length);

        if (toCopy.ownes != null)
            this.ownes = Arrays.copyOf(toCopy.ownes, toCopy.ownes.length);

        if (toCopy.newDesignations != null)
            this.newDesignations = Arrays.copyOf(toCopy.newDesignations, toCopy.newDesignations.length);

        if (toCopy.W != null)
            this.W = toCopy.W.clone();
    }

    /**
     * Sets the maximum number of iterations allowed
     * 
     * @param iterLimit the maximum number of iterations of the KMeans algorithm
     * @throws IllegalArgumentException if {@code iterLimit} is not positive
     */
    public void setMaximumIterations(int iterLimit) {
        if (iterLimit <= 0)
            throw new IllegalArgumentException("iterations must be a positive value, not " + iterLimit);
        this.maximumIterations = iterLimit;
    }

    /**
     * Returns the maximum number of iterations of the KMeans algorithm that will be
     * performed.
     * 
     * @return the maximum number of iterations of the KMeans algorithm that will be
     *         performed.
     */
    public int getMaximumIterations() {
        return maximumIterations;
    }

    /*
     * Clustering without an a priori number of clusters (or with a range of K) is
     * not supported by the kernelized implementation; these always throw.
     */

    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Computes the weighted kernel sum of data point {@code i} against all the
     * points in cluster group {@code clusterID}.
     * 
     * @param i         the index of the data point to query for
     * @param clusterID the cluster index to get the sum of kernel products
     * @param d         the array of cluster assignments, one entry per point in
     *                  {@link #X}
     * @return the sum <big>&Sigma;</big>w<sub>j</sub> k(x<sub>i</sub>,
     *         x<sub>j</sub>), &forall; j, d[<i>j</i>] == <i>clusterID</i>
     */
    protected double evalSumK(int i, int clusterID, int[] d) {
        double sum = 0;
        for (int j = 0; j < X.size(); j++)
            if (d[j] == clusterID)
                sum += W.get(j) * kernel.eval(i, j, X, accel);
        return sum;
    }

    /**
     * Computes the weighted kernel sum of the given data point against all the
     * points in cluster group {@code clusterID}.
     * 
     * @param x         the data point to get the kernel sum of
     * @param qi        the query information for the given data point generated
     *                  from the kernel in use. See
     *                  {@link KernelTrick#getQueryInfo(com.jstarcraft.ai.jsat.linear.Vec) }
     * @param clusterID the cluster index to get the sum of kernel products
     * @param d         the array of cluster assignments
     * @return the sum <big>&Sigma;</big>w<sub>j</sub> k(x, x<sub>j</sub>),
     *         &forall; j, d[<i>j</i>] == <i>clusterID</i>
     */
    protected double evalSumK(Vec x, DoubleList qi, int clusterID, int[] d) {
        double sum = 0;
        for (int j = 0; j < X.size(); j++)
            if (d[j] == clusterID)
                sum += W.get(j) * kernel.eval(j, x, qi, X, accel);
        return sum;
    }

    /**
     * Sets up the internal structure for KernelKMeans. Should be called first
     * before any work is done; {@link #X} must already be set. Every point is
     * assigned uniformly at random to one of the {@code K} clusters, and that
     * assignment is written into both {@code designations} and
     * {@link #newDesignations}.
     * 
     * @param K            the number of clusters to find
     * @param designations the initial designations array to fill with values
     * @param W            the weight for each individual data point, or
     *                     {@code null} to give every point weight 1
     */
    protected void setup(int K, int[] designations, Vec W) {
        accel = kernel.getAccelerationCache(X);

        final int N = X.size();
        selfK = new double[N];
        for (int i = 0; i < selfK.length; i++)
            selfK[i] = kernel.eval(i, i, X, accel);
        ownes = new double[K];
        meanSqrdNorms = new double[K];
        newDesignations = new int[N];

        if (W == null)
            this.W = new ConstantVector(1.0, N);
        else
            this.W = W;

        Random rand = RandomUtil.getRandom();
        for (int i = 0; i < N; i++) {
            int to = rand.nextInt(K);
            ownes[to] += this.W.get(i);
            newDesignations[i] = designations[i] = to;
        }

        normConsts = new double[K];
        updateNormConsts();

        // Accumulate the un-normalized squared norm of each mean. Each off-diagonal
        // pair is visited once (j > i), hence the factor of 2.
        for (int i = 0; i < N; i++) {
            int i_k = designations[i];
            final double w_i = this.W.get(i);
            meanSqrdNorms[i_k] += w_i * selfK[i];
            for (int j = i + 1; j < N; j++)
                if (i_k == designations[j])
                    meanSqrdNorms[i_k] += 2 * w_i * this.W.get(j) * kernel.eval(i, j, X, accel);
        }
    }

    /**
     * Updates the normalizing constants for each mean. Should be called after every
     * change in ownership
     */
    protected void updateNormConsts() {
        for (int i = 0; i < normConsts.length; i++)
            normConsts[i] = 1.0 / (ownes[i] * ownes[i]);
    }

    /**
     * Computes the distance between one data point and a specified mean
     * 
     * @param i            the data point to get the distance for
     * @param k            the mean index to get the distance to
     * @param designations the array of ownership designations for each cluster to
     *                     use
     * @return the distance between data point {@link #X x}<sub>i</sub> and mean
     *         {@code k}
     */
    protected double distance(int i, int k, int[] designations) {
        // max(..., 0) guards against tiny negative values from round-off
        return Math.sqrt(Math.max(selfK[i] - 2.0 / ownes[k] * evalSumK(i, k, designations) + meanSqrdNorms[k] * normConsts[k], 0));
    }

    /**
     * Returns the distance between the given data point and the specified
     * cluster
     * 
     * @param x the data point to get the distance for
     * @param k the cluster id to get the distance to
     * @return the distance between the given data point and the specified cluster
     * @throws IndexOutOfBoundsException if {@code k} is not a valid cluster index
     */
    public double distance(Vec x, int k) {
        return distance(x, kernel.getQueryInfo(x), k);
    }

    /**
     * Returns the distance between the given data point and the specified
     * cluster
     * 
     * @param x  the data point to get the distance for
     * @param qi the query information for the given data point generated for the
     *           kernel in use. See
     *           {@link KernelTrick#getQueryInfo(com.jstarcraft.ai.jsat.linear.Vec) }
     * @param k  the cluster id to get the distance to
     * @return the distance between the given data point and the specified cluster
     * @throws IndexOutOfBoundsException if {@code k} is not a valid cluster index
     */
    public double distance(Vec x, DoubleList qi, int k) {
        if (k >= meanSqrdNorms.length || k < 0)
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + k + " is not a valid index");
        // k(x,x) is evaluated by treating x as a one-element data set with qi as its cache
        return Math.sqrt(Math.max(kernel.eval(0, 0, Arrays.asList(x), qi) - 2.0 / ownes[k] * evalSumK(x, qi, k, newDesignations) + meanSqrdNorms[k] * normConsts[k], 0));
    }

    /**
     * Finds the cluster ID that is closest to the given data point
     * 
     * @param x the data point to get the closest cluster for
     * @return the index of the closest cluster
     */
    public int findClosestCluster(Vec x) {
        return findClosestCluster(x, kernel.getQueryInfo(x));
    }

    /**
     * Finds the cluster ID that is closest to the given data point
     * 
     * @param x  the data point to get the closest cluster for
     * @param qi the query information for the given data point generated for the
     *           kernel in use. See
     *           {@link KernelTrick#getQueryInfo(com.jstarcraft.ai.jsat.linear.Vec) }
     * @return the index of the closest cluster, or {@code -1} if no cluster had a
     *         distance below {@link Double#MAX_VALUE} (e.g. all distances NaN)
     */
    public int findClosestCluster(Vec x, DoubleList qi) {
        double min = Double.MAX_VALUE;
        int min_indx = -1;
        for (int i = 0; i < meanSqrdNorms.length; i++) {
            double dist = distance(x, qi, i);
            if (dist < min) {
                min = dist;
                min_indx = i;
            }
        }

        return min_indx;
    }

    /**
     * Updates the means based off the change of a specific data point. The changes
     * are applied directly to {@link #meanSqrdNorms} and {@link #ownes}.
     * 
     * @param i            the index of the data point to try and update the means
     *                     based on its movement
     * @param designations the old assignments for ownership of each data point to
     *                     one of the means
     * @return {@code 1} if the index changed ownership, {@code 0} if the index did
     *         not change ownership
     */
    protected int updateMeansFromChange(int i, int[] designations) {
        return updateMeansFromChange(i, designations, meanSqrdNorms, ownes);
    }

    /**
     * Accumulates the updates to the means and ownership into the provided arrays.
     * This does not update {@link #meanSqrdNorms}, and is meant to accumulate the
     * change. To apply the changes pass the same arrays to
     * {@link #applyMeanUpdates(double[], double[]) }
     * 
     * @param i            the index of the data point to try and update the means
     *                     based on its movement
     * @param designations the old assignments for ownership of each data point to
     *                     one of the means
     * @param sqrdNorms    the array to place the changes to the squared norms in
     * @param ownership    the array to place the changes to the ownership counts in
     * @return {@code 1} if the index changed ownership, {@code 0} if the index did
     *         not change ownership
     */
    protected int updateMeansFromChange(final int i, final int[] designations, final double[] sqrdNorms, final double[] ownership) {
        final int old_d = designations[i];
        final int new_d = newDesignations[i];

        if (old_d == new_d)// this one has not changed!
            return 0;

        final int N = X.size();
        final double w_i = W.get(i);
        ownership[old_d] -= w_i;
        ownership[new_d] += w_i;

        for (int j = 0; j < N; j++) {
            final double w_j = W.get(j);
            final int oldD_j = designations[j];
            final int newD_j = newDesignations[j];
            if (i == j)// diagonal is an easy case
            {
                sqrdNorms[old_d] -= w_i * selfK[i];
                sqrdNorms[new_d] += w_i * selfK[i];
            } else {
                // handle removing contribution from old mean
                if (old_d == oldD_j) {
                    // only do this for items that were a part of the OLD center

                    if (i > j && oldD_j != newD_j) {
                        /*
                         * j is also being removed from this center. To avoid removing the value k_ij
                         * twice, the point with the smaller index (j) performs the update
                         */
                    } else// safe to remove the k_ij contribution
                        sqrdNorms[old_d] -= 2 * w_i * w_j * kernel.eval(i, j, X, accel);
                }
                // handle adding contribution to new mean
                if (new_d == newD_j) {
                    // only do this for items that are a part of the NEW center

                    if (i > j && oldD_j != newD_j) {
                        /*
                         * j is also being added to this center. To avoid adding the value k_ij twice,
                         * the point with the smaller index (j) performs the update
                         */
                    } else
                        sqrdNorms[new_d] += 2 * w_i * w_j * kernel.eval(i, j, X, accel);
                }
            }
        }

        return 1;
    }

    /**
     * Applies changes accumulated by
     * {@link #updateMeansFromChange(int, int[], double[], double[])} by adding
     * them to {@link #meanSqrdNorms} and {@link #ownes}. {@link #normConsts} is
     * not refreshed here; call {@link #updateNormConsts()} afterwards.
     * 
     * @param sqrdNorms  the accumulated changes to the un-normalized squared norms
     * @param ownerships the accumulated changes to the ownership weights
     */
    protected void applyMeanUpdates(double[] sqrdNorms, double[] ownerships) {
        for (int i = 0; i < sqrdNorms.length; i++) {
            meanSqrdNorms[i] += sqrdNorms[i];
            ownes[i] += ownerships[i];
        }
    }

    /**
     * Computes the distance between two of the means in the clustering
     * 
     * @param k0 the index of the first mean
     * @param k1 the index of the second mean
     * @return the distance between the two
     * @throws IndexOutOfBoundsException if either index is not a valid cluster
     *                                   index
     */
    public double meanToMeanDistance(int k0, int k1) {
        if (k0 >= meanSqrdNorms.length || k0 < 0)
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + k0 + " is not a valid index");
        if (k1 >= meanSqrdNorms.length || k1 < 0)
            throw new IndexOutOfBoundsException("Only " + meanSqrdNorms.length + " clusters. " + k1 + " is not a valid index");

        return meanToMeanDistance(k0, k1, newDesignations);
    }

    /**
     * Computes the distance between two of the means using one set of
     * assignments.
     * 
     * @param k0          the index of the first mean
     * @param k1          the index of the second mean
     * @param assignments the array of assignments for cluster ownership
     * @return the distance between the two means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments) {
        double d = meanSqrdNorms[k0] * normConsts[k0] + meanSqrdNorms[k1] * normConsts[k1] - 2 * dot(k0, k1, assignments);
        return Math.sqrt(Math.max(0, d));// Avoid rare cases where 2*dot might be slightly larger
    }

    /**
     * Computes the distance between two of the means using one set of
     * assignments, optionally in parallel.
     * 
     * @param k0          the index of the first mean
     * @param k1          the index of the second mean
     * @param assignments the array of assignments for cluster ownership
     * @param parallel    {@code true} to compute the kernel sum in parallel
     * @return the distance between the two means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments, boolean parallel) {
        double d = meanSqrdNorms[k0] * normConsts[k0] + meanSqrdNorms[k1] * normConsts[k1] - 2 * dot(k0, k1, assignments, parallel);
        return Math.sqrt(Math.max(0, d));// Avoid rare cases where 2*dot might be slightly larger
    }

    /**
     * Computes the distance between two means whose memberships are taken from
     * two different assignment arrays.
     * 
     * @param k0           the index of the first cluster
     * @param k1           the index of the second cluster
     * @param assignments0 the array of assignments to use for index k0
     * @param assignments1 the array of assignments to use for index k1
     * @param k1SqrdNorm   the <i>normalized</i> squared norm for the mean indicated
     *                     by {@code k1}. (ie: {@link #meanSqrdNorms} multiplied by
     *                     {@link #normConsts})
     * @return the distance between the two means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments0, int[] assignments1, double k1SqrdNorm) {
        double d = meanSqrdNorms[k0] * normConsts[k0] + k1SqrdNorm - 2 * dot(k0, k1, assignments0, assignments1);
        return Math.sqrt(Math.max(0, d));// Avoid rare cases where 2*dot might be slightly larger
    }

    /**
     * Computes the distance between two means whose memberships are taken from
     * two different assignment arrays, optionally in parallel.
     * 
     * @param k0           the index of the first cluster
     * @param k1           the index of the second cluster
     * @param assignments0 the array of assignments to use for index k0
     * @param assignments1 the array of assignments to use for index k1
     * @param k1SqrdNorm   the <i>normalized</i> squared norm for the mean indicated
     *                     by {@code k1}. (ie: {@link #meanSqrdNorms} multiplied by
     *                     {@link #normConsts})
     * @param parallel     {@code true} to compute the kernel sum in parallel
     * @return the distance between the two means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments0, int[] assignments1, double k1SqrdNorm, boolean parallel) {
        double d = meanSqrdNorms[k0] * normConsts[k0] + k1SqrdNorm - 2 * dot(k0, k1, assignments0, assignments1, parallel);
        return Math.sqrt(Math.max(0, d));// Avoid rare cases where 2*dot might be slightly larger
    }

    /**
     * dot product between two different clusters from one set of cluster
     * assignments
     * 
     * @param k0         the index of the first cluster
     * @param k1         the index of the second cluster
     * @param assignment the array of assignments for cluster ownership
     * @return the dot product between the two clusters.
     */
    private double dot(final int k0, final int k1, final int[] assignment) {
        return dot(k0, k1, assignment, assignment);
    }

    /**
     * dot product between two different clusters from one set of cluster
     * assignments
     * 
     * @param k0         the index of the first cluster
     * @param k1         the index of the second cluster
     * @param assignment the array of assignments for cluster ownership
     * @param parallel   {@code true} to compute the kernel sum in parallel
     * @return the dot product between the two clusters.
     */
    private double dot(final int k0, final int k1, final int[] assignment, boolean parallel) {
        return dot(k0, k1, assignment, assignment, parallel);
    }

    /**
     * dot product between two different clusters from different sets of cluster
     * assignments. Two different assignment arrays are used to allow overlapping
     * assignment of points to the clusters.
     *
     * @param k0          the first cluster to take the dot product with
     * @param k1          the second cluster to take the dot product with.
     * @param assignment0 vector containing assignment values, will be used to
     *                    determine which points belong to k0
     * @param assignment1 vector containing assignment values, will be used to
     *                    determine which points belong to k1
     * @return the dot product between the two clusters, normalized by the product
     *         of the two clusters' total weights
     */
    private double dot(final int k0, final int k1, final int[] assignment0, final int[] assignment1) {
        double dot = 0;
        final int N = X.size();
        double a = 0, b = 0;
        /*
         * Below, unless i&amp;j are somehow in the same cluster - nothing bad will
         * happen
         */
        for (int i = 0; i < N; i++) {
            if (assignment0[i] != k0)
                continue;
            final double w_i = W.get(i);
            a += w_i;
            for (int j = 0; j < N; j++) {
                if (assignment1[j] != k1)
                    continue;
                final double w_j = W.get(j);
                dot += w_i * w_j * kernel.eval(i, j, X, accel);
            }
        }
        for (int j = 0; j < N; j++)
            if (assignment1[j] == k1)
                b += W.get(j);
        return dot / (a * b);
    }

    /**
     * dot product between two different clusters from different sets of cluster
     * assignments. Two different assignment arrays are used to allow overlapping
     * assignment of points to the clusters.
     *
     * @param k0          the first cluster to take the dot product with
     * @param k1          the second cluster to take the dot product with.
     * @param assignment0 vector containing assignment values, will be used to
     *                    determine which points belong to k0
     * @param assignment1 vector containing assignment values, will be used to
     *                    determine which points belong to k1
     * @param parallel    {@code true} to compute the kernel sum in parallel
     * @return the dot product between the two clusters, normalized by the product
     *         of the two clusters' total weights
     */
    private double dot(final int k0, final int k1, final int[] assignment0, final int[] assignment1, boolean parallel) {
        final int N = X.size();

        /*
         * Below, unless i&amp;j are somehow in the same cluster - nothing bad will
         * happen
         */
        // Each index i contributes its partial kernel sum; the reducer adds the
        // partial sums. The reduced value must be captured, otherwise it is lost.
        final double dot = ParallelUtils.run(parallel, N, (i) -> {
            if (assignment0[i] != k0)
                return 0.0;
            final double w_i = W.get(i);
            double localDot = 0;
            for (int j = 0; j < N; j++) {
                if (assignment1[j] != k1)
                    continue;
                final double w_j = W.get(j);
                localDot += w_i * w_j * kernel.eval(i, j, X, accel);
            }
            return localDot;
        }, (t, u) -> t + u);

        // Total weight of the points owned by k0 (under assignment0) and by k1
        // (under assignment1); must match the serial version's normalization.
        double a = 0, b = 0;
        for (int i = 0; i < N; i++) {
            if (assignment0[i] == k0)
                a += W.get(i);
            if (assignment1[i] == k1)
                b += W.get(i);
        }

        return dot / (a * b);
    }

    @Override
    abstract public KernelKMeans clone();

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

}
